from functools import wraps
from typing import Any, Dict, Optional
import json
import time
import hashlib
from datetime import datetime, timedelta

class MemoryCache:
    """In-process cache backed by plain dicts, with per-entry TTL and LRU eviction.

    Each entry stores its value plus absolute ``expires_at`` / ``created_at``
    timestamps taken from ``time.time()``. Expired entries are removed lazily
    when they are looked up, and eagerly when the cache is full and room is
    needed for a new key.

    NOTE(review): no locking is performed here; concurrent access from several
    threads is not synchronized.
    """

    def __init__(self, max_size: int = 1000, default_ttl: int = 300):
        """Create an empty cache.

        Args:
            max_size: Maximum number of entries kept before eviction kicks in.
            default_ttl: Lifetime in seconds used when ``set`` receives no
                ``ttl`` (300 s = 5 minutes by default).
        """
        self._cache: Dict[str, Dict[str, Any]] = {}
        # key -> timestamp of the most recent set/get; drives LRU eviction.
        self._access_times: Dict[str, float] = {}
        self.max_size = max_size
        self.default_ttl = default_ttl

    def _generate_key(self, prefix: str, *args, **kwargs) -> str:
        """Return an MD5 hex digest identifying ``prefix`` plus the call arguments.

        Positional arguments are rendered via the tuple's ``str()`` (i.e. each
        element's ``repr``) and keyword arguments are sorted by name, so the
        same call yields the same key regardless of keyword order. The result
        is opaque: the prefix cannot be read back from the digest.
        """
        key_data = f"{prefix}:{args}:{sorted(kwargs.items())}"
        return hashlib.md5(key_data.encode()).hexdigest()

    def _is_expired(self, key: str) -> bool:
        """Return True if *key* is absent or its expiry time has been reached.

        Side effect: an expired entry is removed from the cache.
        """
        entry = self._cache.get(key)
        if entry is None:
            return True
        # ``<=`` so that an entry stored with ttl=0 is already expired.
        if entry['expires_at'] <= time.time():
            self._remove(key)
            return True
        return False

    def _remove(self, key: str) -> None:
        """Drop *key* from both internal dicts; missing keys are ignored."""
        self._cache.pop(key, None)
        self._access_times.pop(key, None)

    def _purge_expired(self) -> int:
        """Remove every expired entry and return how many were removed."""
        now = time.time()
        expired = [k for k, entry in self._cache.items() if entry['expires_at'] <= now]
        for k in expired:
            self._remove(k)
        return len(expired)

    def _evict_if_needed(self) -> None:
        """Make room for one new entry when the cache is at capacity.

        Expired entries are discarded first; only if the cache is still full
        are least-recently-accessed live entries evicted.
        """
        if len(self._cache) < self.max_size:
            return
        self._purge_expired()
        # A loop (not a single eviction) also copes with max_size having been
        # lowered at runtime, and the emptiness check avoids min() on nothing.
        while self._access_times and len(self._cache) >= self.max_size:
            lru_key = min(self._access_times, key=self._access_times.__getitem__)
            self._remove(lru_key)

    def get(self, key: str) -> Optional[Any]:
        """Return the cached value for *key*, or None if it is missing or expired.

        A successful lookup refreshes the entry's LRU timestamp. A stored value
        of ``None`` is indistinguishable from a miss.
        """
        if self._is_expired(key):
            return None
        self._access_times[key] = time.time()
        return self._cache[key]['value']

    def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
        """Store *value* under *key* for *ttl* seconds.

        Args:
            key: Cache key.
            value: Value to store.
            ttl: Lifetime in seconds; ``None`` means ``self.default_ttl``.
                ``0`` (or a negative number) stores an already-expired entry.
        """
        # Overwriting an existing key does not grow the cache, so it must not
        # trigger eviction of some other entry.
        if key not in self._cache:
            self._evict_if_needed()

        now = time.time()
        lifetime = self.default_ttl if ttl is None else ttl
        self._cache[key] = {
            'value': value,
            'expires_at': now + lifetime,
            'created_at': now,
        }
        self._access_times[key] = now

    def delete(self, key: str) -> bool:
        """Remove *key*; return True if it was present (expired or not)."""
        if key in self._cache:
            self._remove(key)
            return True
        return False

    def clear(self) -> None:
        """Remove every entry."""
        self._cache.clear()
        self._access_times.clear()

    def get_stats(self) -> Dict[str, Any]:
        """Return a snapshot of cache statistics.

        ``expired_items`` counts entries past their expiry that have not yet
        been purged. ``memory_usage_mb`` is only a rough estimate: it is the
        length of the cache's string representation, not real memory usage,
        and computing it renders every stored value.
        """
        now = time.time()
        expired_count = sum(
            1 for entry in self._cache.values() if entry['expires_at'] <= now
        )

        return {
            'total_items': len(self._cache),
            'expired_items': expired_count,
            'memory_usage_mb': len(str(self._cache)) / 1024 / 1024,
            'max_size': self.max_size,
        }

cache: MemoryCache = MemoryCache()  # process-wide shared instance used by the helpers below

def cached(ttl: int = 300, key_prefix: str = "default"):
    """Decorator that memoizes a function's results in the shared ``cache``.

    Keys have the form ``"<key_prefix>:<func name>:<digest>"``, where the
    digest comes from ``cache._generate_key`` over that namespace and the call
    arguments. Keeping the namespace readable in the key is what allows
    ``invalidate_cache_pattern`` (substring match) and ``clear_cache`` (prefix
    match) to find these entries; a bare digest could never match.

    Args:
        ttl: Lifetime of each cached result, in seconds.
        key_prefix: Namespace for this function's keys, e.g. "questions".

    The wrapped function gains two attributes:
        clear_cache(): remove only this function's entries; returns the count.
        cache_stats(): statistics for the whole shared cache.

    NOTE: a result of ``None`` cannot be told apart from a cache miss, so such
    results are never served from the cache and the call re-executes.
    NOTE(review): arguments are keyed by their textual repr — verify callers
    pass arguments with stable reprs (default object reprs embed an id).
    """
    def decorator(func):
        namespace = f"{key_prefix}:{func.__name__}"
        # Trailing ':' keeps "p:get" from prefix-matching "p:get_all".
        key_head = f"{namespace}:"

        @wraps(func)
        def wrapper(*args, **kwargs):
            cache_key = key_head + cache._generate_key(namespace, *args, **kwargs)

            cached_result = cache.get(cache_key)
            if cached_result is not None:
                return cached_result

            result = func(*args, **kwargs)
            cache.set(cache_key, result, ttl)
            return result

        def clear_cache() -> int:
            """Remove this function's cached entries and return how many were removed."""
            own_keys = [k for k in cache._cache if k.startswith(key_head)]
            for k in own_keys:
                cache.delete(k)
            return len(own_keys)

        wrapper.clear_cache = clear_cache
        wrapper.cache_stats = lambda: cache.get_stats()

        return wrapper
    return decorator

def invalidate_cache_pattern(pattern: str):
    """Delete every entry of the shared cache whose key contains *pattern*.

    Matching is a plain substring test on the stored key, so only keys that
    carry readable text can match (an opaque hash digest will not).

    Returns:
        The number of entries that were deleted.
    """
    # Collect first, then delete: the dict must not change size while iterating.
    doomed = [stored_key for stored_key in cache._cache if pattern in stored_key]
    for stored_key in doomed:
        cache.delete(stored_key)
    return len(doomed)

class CacheManager:
    """Static helpers for invalidating groups of entries in the shared cache."""

    @staticmethod
    def clear_questions_cache():
        """Drop entries whose key mentions "questions"; return the number dropped."""
        return invalidate_cache_pattern(pattern="questions")

    @staticmethod
    def clear_submissions_cache():
        """Drop entries whose key mentions "submissions"; return the number dropped."""
        return invalidate_cache_pattern(pattern="submissions")

    @staticmethod
    def clear_analysis_cache():
        """Drop entries whose key mentions "analysis"; return the number dropped."""
        return invalidate_cache_pattern(pattern="analysis")

    @staticmethod
    def get_cache_info():
        """Return the shared cache's statistics dictionary."""
        stats = cache.get_stats()
        return stats